Sparkify Project Workspace

This workspace contains a tiny subset (128MB) of the full dataset available (12GB). Feel free to use this workspace to build your project, or to explore a smaller subset with Spark before deploying your cluster on the cloud. Instructions for setting up your Spark cluster are included in the last lesson of the Extracurricular Spark Course content.

You can follow the steps below to guide the data analysis and model building portions of this project.

Introduction
In this project, we use the user activity log of the music streaming service Sparkify to build a classification model that identifies users who are likely to cancel the service out of dissatisfaction.

Software Requirements

  • Python3
  • PySpark 2.4.3
  • Plotly 2.0.15

Data Requirements

  • Sparkify user activity log from October 1 to early December 2018
  • Filename: "mini_sparkify_event_data.json"
  • The data contains personal information and is accessible only to teachers and students.

Steps

  • Load and Clean Dataset
  • Exploratory Data Analysis
    • Define Churn(Define Objective variable)
    • Explore Data
  • Feature Engineering
  • Modeling
    • Logistic Regression
    • GBTClassifier
    • Random Forest
  • Conclusion
In [1]:
# import libraries
import pyspark
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import isnan, isnull, count, when, col, desc, udf, sort_array, asc, avg, datediff, weekofyear
from pyspark.sql.functions import to_date, from_unixtime
from pyspark.sql.types import StringType
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import sum as Fsum
from pyspark.sql import Window

from pyspark.sql.types import DateType

import datetime

import numpy as np
import pandas as pd
#%matplotlib inline
import matplotlib.pyplot as plt
In [2]:
import plotly.offline as offline
import plotly.graph_objs as go
offline.init_notebook_mode()
In [ ]:
 
In [3]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import MinMaxScaler,StandardScaler, VectorAssembler
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
In [4]:
%%HTML
<style>
    div#notebook-container    { width: 100%; }
    div#menubar-container     { width: 65%; }
    div#maintoolbar-container { width: 99%; }
</style>
In [5]:
# create a Spark session
spark = SparkSession \
    .builder \
    .appName("Sparkify Project") \
    .getOrCreate()
In [6]:
spark.sparkContext.getConf().getAll()
Out[6]:
[('spark.app.name', 'Sparkify Project'),
 ('spark.driver.port', '37955'),
 ('spark.rdd.compress', 'True'),
 ('spark.serializer.objectStreamReset', '100'),
 ('spark.master', 'local[*]'),
 ('spark.executor.id', 'driver'),
 ('spark.submit.deployMode', 'client'),
 ('spark.driver.host',
  'instance-2.asia-northeast2-a.c.nice-carving-276211.internal'),
 ('spark.ui.showConsoleProgress', 'true'),
 ('spark.app.id', 'local-1599562707554')]
In [7]:
spark
Out[7]:

SparkSession - in-memory

SparkContext

Spark UI

Version
v2.4.6
Master
local[*]
AppName
Sparkify Project

Load and Clean Dataset

In this workspace, the mini-dataset file is mini_sparkify_event_data.json. Load and clean the dataset, checking for invalid or missing data - for example, records without userids or sessionids.

In [8]:
path = "mini_sparkify_event_data.json"
user_log = spark.read.json(path)
In [9]:
user_log.show(n=2)
+----------------+---------+---------+------+-------------+--------+---------+-----+--------------------+------+--------+-------------+---------+---------+------+-------------+--------------------+------+
|          artist|     auth|firstName|gender|itemInSession|lastName|   length|level|            location|method|    page| registration|sessionId|     song|status|           ts|           userAgent|userId|
+----------------+---------+---------+------+-------------+--------+---------+-----+--------------------+------+--------+-------------+---------+---------+------+-------------+--------------------+------+
|  Martha Tilston|Logged In|    Colin|     M|           50| Freeman|277.89016| paid|     Bakersfield, CA|   PUT|NextSong|1538173362000|       29|Rockpools|   200|1538352117000|Mozilla/5.0 (Wind...|    30|
|Five Iron Frenzy|Logged In|    Micah|     M|           79|    Long|236.09424| free|Boston-Cambridge-...|   PUT|NextSong|1538331630000|        8|   Canada|   200|1538352180000|"Mozilla/5.0 (Win...|     9|
+----------------+---------+---------+------+-------------+--------+---------+-----+--------------------+------+--------+-------------+---------+---------+------+-------------+--------------------+------+
only showing top 2 rows

Clean Dataset

In [10]:
print("row            :",user_log.select(col('*')).count())
print("sessionId Nan :",user_log.filter(isnan(col('sessionId'))).count())
print("sessionId NULL :",user_log.filter(isnull(col('sessionId'))).count())
print("sessionId \"\"   :",user_log.filter((col('sessionId'))=="").count())
print("userId Nan   :",user_log.filter(isnan(col('userId'))).count())
print("userId Null  :",user_log.filter(isnull(col('userId'))).count())
print("userId \"\"    :",user_log.filter((col('userId'))=="").count())
row            : 286500
sessionId Nan : 0
sessionId NULL : 0
sessionId ""   : 0
userId Nan   : 0
userId Null  : 0
userId ""    : 8346
In [11]:
user_log=user_log.filter((col('userId'))!="")
In [12]:
print("row            :",user_log.select(col('*')).count())
print("sessionId Nan :",user_log.filter(isnan(col('sessionId'))).count())
print("sessionId NULL :",user_log.filter(isnull(col('sessionId'))).count())
print("sessionId \"\"   :",user_log.filter((col('sessionId'))=="").count())
print("userId Nan   :",user_log.filter(isnan(col('userId'))).count())
print("userId Null  :",user_log.filter(isnull(col('userId'))).count())
print("userId \"\"    :",user_log.filter((col('userId'))=="").count())
row            : 278154
sessionId Nan : 0
sessionId NULL : 0
sessionId ""   : 0
userId Nan   : 0
userId Null  : 0
userId ""    : 0
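
The audit above counts NaN, NULL, and empty-string values separately before dropping the empty `userId` rows. A minimal pandas sketch of the same audit, on assumed toy data (the notebook runs the equivalent checks on the Spark DataFrame):

```python
import pandas as pd

def audit_missing(df, columns):
    """Count null (NaN/None) and empty-string values per column."""
    report = {}
    for c in columns:
        report[c] = {
            "null": int(df[c].isna().sum()),
            "empty": int((df[c] == "").sum()),
        }
    return report

# Toy events frame: one empty userId and one missing userId.
events = pd.DataFrame({
    "userId": ["30", "9", "", None],
    "sessionId": [29, 8, 5, 7],
})
print(audit_missing(events, ["userId", "sessionId"]))
```

Counting `""` separately matters here: the empty string is not NaN/NULL, so `isnan`/`isnull` checks alone would miss the 8,346 anonymous rows.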

Exploratory Data Analysis

When you're working with the full dataset, perform EDA by loading a small subset of the data and doing basic manipulations within Spark. In this workspace, you are already provided a small subset of data you can explore.

Define Churn

Once you've done some preliminary analysis, create a column Churn to use as the label for your model. I suggest using the Cancellation Confirmation events to define your churn, which happen for both paid and free users. As a bonus task, you can also look into the Downgrade events.

Explore Data

Once you've defined churn, perform some exploratory data analysis to observe the behavior for users who stayed vs users who churned. You can start by exploring aggregates on these two groups of users, observing how much of a specific action they experienced per a certain time unit or number of songs played.

In [13]:
user_log.printSchema()
root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)

In [14]:
explore= user_log.select('status').dropDuplicates().collect()
set(explore)
Out[14]:
{Row(status=200), Row(status=307), Row(status=404)}
In [15]:
explore= user_log.select('level').dropDuplicates().collect()
set(explore)
Out[15]:
{Row(level='free'), Row(level='paid')}
In [16]:
explore= user_log.select('auth').dropDuplicates().collect()
set(explore)
Out[16]:
{Row(auth='Cancelled'), Row(auth='Logged In')}
In [17]:
explore= user_log.select('method').dropDuplicates().collect()
set(explore)
Out[17]:
{Row(method='GET'), Row(method='PUT')}
In [18]:
explore= user_log.select('userAgent').dropDuplicates().collect()
set(explore)
Out[18]:
{Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10) AppleWebKit/600.1.3 (KHTML, like Gecko) Version/8.0 Safari/600.1.3"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10) AppleWebKit/600.1.8 (KHTML, like Gecko) Version/8.0 Safari/600.1.8"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_6_8) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_5) AppleWebKit/537.77.4 (KHTML, like Gecko) Version/6.1.5 Safari/537.77.4"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/37.0.2062.94 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_2) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_2) AppleWebKit/537.74.9 (KHTML, like Gecko) Version/7.0.2 Safari/537.74.9"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_2) AppleWebKit/537.75.14 (KHTML, like Gecko) Version/7.0.3 Safari/537.75.14"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_3) AppleWebKit/537.76.4 (KHTML, like Gecko) Version/7.0.4 Safari/537.76.4"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/35.0.1916.153 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/37.0.2062.94 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.77.4 (KHTML, like Gecko) Version/7.0.5 Safari/537.77.4"'),
 Row(userAgent='"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.78.2 (KHTML, like Gecko) Version/7.0.6 Safari/537.78.2"'),
 Row(userAgent='"Mozilla/5.0 (Windows NT 5.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Windows NT 5.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/35.0.1916.153 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/37.0.2062.103 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/37.0.2062.94 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Windows NT 6.2; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Windows NT 6.2; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Windows NT 6.3; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (Windows NT 6.3; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"'),
 Row(userAgent='"Mozilla/5.0 (iPad; CPU OS 7_1_1 like Mac OS X) AppleWebKit/537.51.2 (KHTML, like Gecko) Version/7.0 Mobile/11D201 Safari/9537.53"'),
 Row(userAgent='"Mozilla/5.0 (iPad; CPU OS 7_1_2 like Mac OS X) AppleWebKit/537.51.2 (KHTML, like Gecko) Version/7.0 Mobile/11D257 Safari/9537.53"'),
 Row(userAgent='"Mozilla/5.0 (iPhone; CPU iPhone OS 7_1 like Mac OS X) AppleWebKit/537.51.2 (KHTML, like Gecko) Version/7.0 Mobile/11D167 Safari/9537.53"'),
 Row(userAgent='"Mozilla/5.0 (iPhone; CPU iPhone OS 7_1_1 like Mac OS X) AppleWebKit/537.51.2 (KHTML, like Gecko) Version/7.0 Mobile/11D201 Safari/9537.53"'),
 Row(userAgent='"Mozilla/5.0 (iPhone; CPU iPhone OS 7_1_2 like Mac OS X) AppleWebKit/537.51.2 (KHTML, like Gecko) Version/7.0 Mobile/11D257 Safari/9537.53"'),
 Row(userAgent='Mozilla/5.0 (Macintosh; Intel Mac OS X 10.6; rv:31.0) Gecko/20100101 Firefox/31.0'),
 Row(userAgent='Mozilla/5.0 (Macintosh; Intel Mac OS X 10.7; rv:31.0) Gecko/20100101 Firefox/31.0'),
 Row(userAgent='Mozilla/5.0 (Macintosh; Intel Mac OS X 10.8; rv:31.0) Gecko/20100101 Firefox/31.0'),
 Row(userAgent='Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; rv:31.0) Gecko/20100101 Firefox/31.0'),
 Row(userAgent='Mozilla/5.0 (Windows NT 6.0; rv:31.0) Gecko/20100101 Firefox/31.0'),
 Row(userAgent='Mozilla/5.0 (Windows NT 6.1; WOW64; Trident/7.0; rv:11.0) like Gecko'),
 Row(userAgent='Mozilla/5.0 (Windows NT 6.1; WOW64; rv:24.0) Gecko/20100101 Firefox/24.0'),
 Row(userAgent='Mozilla/5.0 (Windows NT 6.1; WOW64; rv:30.0) Gecko/20100101 Firefox/30.0'),
 Row(userAgent='Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0'),
 Row(userAgent='Mozilla/5.0 (Windows NT 6.1; WOW64; rv:32.0) Gecko/20100101 Firefox/32.0'),
 Row(userAgent='Mozilla/5.0 (Windows NT 6.1; rv:31.0) Gecko/20100101 Firefox/31.0'),
 Row(userAgent='Mozilla/5.0 (Windows NT 6.2; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0'),
 Row(userAgent='Mozilla/5.0 (Windows NT 6.3; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0'),
 Row(userAgent='Mozilla/5.0 (X11; Linux x86_64; rv:31.0) Gecko/20100101 Firefox/31.0'),
 Row(userAgent='Mozilla/5.0 (X11; Ubuntu; Linux i686; rv:31.0) Gecko/20100101 Firefox/31.0'),
 Row(userAgent='Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:31.0) Gecko/20100101 Firefox/31.0'),
 Row(userAgent='Mozilla/5.0 (compatible; MSIE 10.0; Windows NT 6.1; WOW64; Trident/6.0)'),
 Row(userAgent='Mozilla/5.0 (compatible; MSIE 9.0; Windows NT 6.1; Trident/5.0)'),
 Row(userAgent='Mozilla/5.0 (compatible; MSIE 9.0; Windows NT 6.1; WOW64; Trident/5.0)')}
In [19]:
explore= user_log.select('location').dropDuplicates().collect()
set(explore)
Out[19]:
{Row(location='Albany, OR'),
 Row(location='Albany-Schenectady-Troy, NY'),
 Row(location='Alexandria, LA'),
 Row(location='Allentown-Bethlehem-Easton, PA-NJ'),
 Row(location='Anchorage, AK'),
 Row(location='Atlanta-Sandy Springs-Roswell, GA'),
 Row(location='Atlantic City-Hammonton, NJ'),
 Row(location='Austin-Round Rock, TX'),
 Row(location='Bakersfield, CA'),
 Row(location='Baltimore-Columbia-Towson, MD'),
 Row(location='Billings, MT'),
 Row(location='Birmingham-Hoover, AL'),
 Row(location='Boston-Cambridge-Newton, MA-NH'),
 Row(location='Boulder, CO'),
 Row(location='Bozeman, MT'),
 Row(location='Bridgeport-Stamford-Norwalk, CT'),
 Row(location='Buffalo-Cheektowaga-Niagara Falls, NY'),
 Row(location='Charlotte-Concord-Gastonia, NC-SC'),
 Row(location='Chicago-Naperville-Elgin, IL-IN-WI'),
 Row(location='Cincinnati, OH-KY-IN'),
 Row(location='Cleveland-Elyria, OH'),
 Row(location='Colorado Springs, CO'),
 Row(location='Columbus, GA-AL'),
 Row(location='Concord, NH'),
 Row(location='Cookeville, TN'),
 Row(location='Corpus Christi, TX'),
 Row(location='Dallas-Fort Worth-Arlington, TX'),
 Row(location='Danville, VA'),
 Row(location='Dayton, OH'),
 Row(location='Deltona-Daytona Beach-Ormond Beach, FL'),
 Row(location='Denver-Aurora-Lakewood, CO'),
 Row(location='Detroit-Warren-Dearborn, MI'),
 Row(location='Dubuque, IA'),
 Row(location='Duluth, MN-WI'),
 Row(location='Fairbanks, AK'),
 Row(location='Flint, MI'),
 Row(location='Gainesville, FL'),
 Row(location='Greensboro-High Point, NC'),
 Row(location='Greenville-Anderson-Mauldin, SC'),
 Row(location='Hagerstown-Martinsburg, MD-WV'),
 Row(location='Hartford-West Hartford-East Hartford, CT'),
 Row(location='Houston-The Woodlands-Sugar Land, TX'),
 Row(location='Indianapolis-Carmel-Anderson, IN'),
 Row(location='Ionia, MI'),
 Row(location='Jackson, MS'),
 Row(location='Jacksonville, FL'),
 Row(location='Jacksonville, NC'),
 Row(location='Kankakee, IL'),
 Row(location='Kansas City, MO-KS'),
 Row(location='Kingsport-Bristol-Bristol, TN-VA'),
 Row(location='Las Vegas-Henderson-Paradise, NV'),
 Row(location='Laurel, MS'),
 Row(location='Lexington-Fayette, KY'),
 Row(location='Little Rock-North Little Rock-Conway, AR'),
 Row(location='Logan, UT-ID'),
 Row(location='London, KY'),
 Row(location='Los Angeles-Long Beach-Anaheim, CA'),
 Row(location='Louisville/Jefferson County, KY-IN'),
 Row(location='Manchester-Nashua, NH'),
 Row(location='McAllen-Edinburg-Mission, TX'),
 Row(location='Memphis, TN-MS-AR'),
 Row(location='Miami-Fort Lauderdale-West Palm Beach, FL'),
 Row(location='Milwaukee-Waukesha-West Allis, WI'),
 Row(location='Minneapolis-St. Paul-Bloomington, MN-WI'),
 Row(location='Monroe, LA'),
 Row(location='Montgomery, AL'),
 Row(location='Morgantown, WV'),
 Row(location='Muncie, IN'),
 Row(location='Myrtle Beach-Conway-North Myrtle Beach, SC-NC'),
 Row(location='Napa, CA'),
 Row(location='New Haven-Milford, CT'),
 Row(location='New Philadelphia-Dover, OH'),
 Row(location='New York-Newark-Jersey City, NY-NJ-PA'),
 Row(location='North Wilkesboro, NC'),
 Row(location='Oklahoma City, OK'),
 Row(location='Omaha-Council Bluffs, NE-IA'),
 Row(location='Orlando-Kissimmee-Sanford, FL'),
 Row(location='Oxnard-Thousand Oaks-Ventura, CA'),
 Row(location='Philadelphia-Camden-Wilmington, PA-NJ-DE-MD'),
 Row(location='Phoenix-Mesa-Scottsdale, AZ'),
 Row(location='Pittsburgh, PA'),
 Row(location='Pontiac, IL'),
 Row(location='Port St. Lucie, FL'),
 Row(location='Portland-Vancouver-Hillsboro, OR-WA'),
 Row(location='Price, UT'),
 Row(location='Providence-Warwick, RI-MA'),
 Row(location='Quincy, IL-MO'),
 Row(location='Raleigh, NC'),
 Row(location='Riverside-San Bernardino-Ontario, CA'),
 Row(location='Roanoke, VA'),
 Row(location='Sacramento--Roseville--Arden-Arcade, CA'),
 Row(location='Salinas, CA'),
 Row(location='San Antonio-New Braunfels, TX'),
 Row(location='San Diego-Carlsbad, CA'),
 Row(location='San Francisco-Oakland-Hayward, CA'),
 Row(location='San Jose-Sunnyvale-Santa Clara, CA'),
 Row(location='Santa Maria-Santa Barbara, CA'),
 Row(location='Scranton--Wilkes-Barre--Hazleton, PA'),
 Row(location='Seattle-Tacoma-Bellevue, WA'),
 Row(location='Selma, AL'),
 Row(location='Spokane-Spokane Valley, WA'),
 Row(location='St. Louis, MO-IL'),
 Row(location='Sterling, IL'),
 Row(location='Stockton-Lodi, CA'),
 Row(location='Syracuse, NY'),
 Row(location='Tallahassee, FL'),
 Row(location='Tampa-St. Petersburg-Clearwater, FL'),
 Row(location='Troy, AL'),
 Row(location='Truckee-Grass Valley, CA'),
 Row(location='Vineland-Bridgeton, NJ'),
 Row(location='Virginia Beach-Norfolk-Newport News, VA-NC'),
 Row(location='Washington-Arlington-Alexandria, DC-VA-MD-WV'),
 Row(location='Wilson, NC'),
 Row(location='Winston-Salem, NC')}
In [20]:
explore= user_log.select('page').dropDuplicates().collect()
set(explore)
Out[20]:
{Row(page='About'),
 Row(page='Add Friend'),
 Row(page='Add to Playlist'),
 Row(page='Cancel'),
 Row(page='Cancellation Confirmation'),
 Row(page='Downgrade'),
 Row(page='Error'),
 Row(page='Help'),
 Row(page='Home'),
 Row(page='Logout'),
 Row(page='NextSong'),
 Row(page='Roll Advert'),
 Row(page='Save Settings'),
 Row(page='Settings'),
 Row(page='Submit Downgrade'),
 Row(page='Submit Upgrade'),
 Row(page='Thumbs Down'),
 Row(page='Thumbs Up'),
 Row(page='Upgrade')}
In [21]:
user_log.filter("page = 'Submit Downgrade'").show(3)
+------+---------+---------+------+-------------+--------+------+-----+--------------------+------+----------------+-------------+---------+----+------+-------------+--------------------+------+
|artist|     auth|firstName|gender|itemInSession|lastName|length|level|            location|method|            page| registration|sessionId|song|status|           ts|           userAgent|userId|
+------+---------+---------+------+-------------+--------+------+-----+--------------------+------+----------------+-------------+---------+----+------+-------------+--------------------+------+
|  null|Logged In|     Kael|     M|           47|   Baker|  null| paid|Kingsport-Bristol...|   PUT|Submit Downgrade|1533102330000|      249|null|   307|1538393619000|"Mozilla/5.0 (Mac...|   131|
|  null|Logged In|   Calvin|     M|           17|Marshall|  null| paid|      Pittsburgh, PA|   PUT|Submit Downgrade|1537120757000|      313|null|   307|1538516445000|"Mozilla/5.0 (Mac...|    38|
|  null|Logged In|  Kaylenn|     F|          354| Jenkins|  null| paid|           Price, UT|   PUT|Submit Downgrade|1535903878000|      479|null|   307|1538835479000|"Mozilla/5.0 (Mac...|   141|
+------+---------+---------+------+-------------+--------+------+-----+--------------------+------+----------------+-------------+---------+----+------+-------------+--------------------+------+
only showing top 3 rows

In [22]:
user_log.filter("page = 'Cancellation Confirmation'").show(3)
+------+---------+---------+------+-------------+--------+------+-----+--------------------+------+--------------------+-------------+---------+----+------+-------------+--------------------+------+
|artist|     auth|firstName|gender|itemInSession|lastName|length|level|            location|method|                page| registration|sessionId|song|status|           ts|           userAgent|userId|
+------+---------+---------+------+-------------+--------+------+-----+--------------------+------+--------------------+-------------+---------+----+------+-------------+--------------------+------+
|  null|Cancelled|   Adriel|     M|          104| Mendoza|  null| paid|  Kansas City, MO-KS|   GET|Cancellation Conf...|1535623466000|      514|null|   200|1538943990000|"Mozilla/5.0 (Mac...|    18|
|  null|Cancelled|    Diego|     M|           56|   Mckee|  null| paid|Phoenix-Mesa-Scot...|   GET|Cancellation Conf...|1537167593000|      540|null|   200|1539033046000|"Mozilla/5.0 (iPh...|    32|
|  null|Cancelled|    Mason|     M|           10|    Hart|  null| free|  Corpus Christi, TX|   GET|Cancellation Conf...|1533157139000|      174|null|   200|1539318918000|"Mozilla/5.0 (Mac...|   125|
+------+---------+---------+------+-------------+--------+------+-----+--------------------+------+--------------------+-------------+---------+----+------+-------------+--------------------+------+
only showing top 3 rows

In [23]:
user_log.select(["userId", "firstname","itemInSession" ,"page", "level", "song"]).where(user_log.userId == "143").sort("ts").show(3)
+------+---------+-------------+--------+-----+--------------------+
|userId|firstname|itemInSession|    page|level|                song|
+------+---------+-------------+--------+-----+--------------------+
|   143|    Molly|            0|NextSong| free|Will You (Single ...|
|   143|    Molly|            1|NextSong| free|The Geeks Were Right|
|   143|    Molly|            2|NextSong| free|Pursuit Of Happin...|
+------+---------+-------------+--------+-----+--------------------+
only showing top 3 rows

Define Churn

In [24]:
def plot_hist(df,param,title,x_axis,y_axis):
    """
    Plot a histogram with Plotly
    
    Input
        df    : dataframe that includes the column 'Churn' and the parameter to plot
        param : parameter name
        title : graph title
        x_axis: x-axis name
        y_axis: y-axis name
    
    """
    Churn1 = go.Histogram(x=df.filter((col('Churn'))==1).toPandas()[param],name="Churn=1", opacity=0.5)
    Churn0 = go.Histogram(x=df.filter((col('Churn'))!=1).toPandas()[param],name="Churn=0", opacity=0.5)
    layout = go.Layout(
        title = title,
        xaxis = dict(title=x_axis),
        yaxis = dict(title=y_axis),
        bargap=0.2,
        bargroupgap=0.1,
        width=800,
        height=300
    )
    fig = dict(data=[Churn1,Churn0], layout=layout)
    offline.iplot(fig, filename=param)

Find churned users from the "Cancellation Confirmation" event: flag each cancellation in a column (named "upgrade" here) and take a window sum per user to propagate the Churn label to all of that user's events.

In [25]:
windowval = Window\
    .partitionBy("userId")\
    .orderBy(desc("ts"))\
    .rangeBetween(Window.unboundedPreceding, 0)

flag_Cancellation_event = udf(lambda x: 1 if x == "Cancellation Confirmation" else 0, IntegerType())
#flag_downgrade_event = udf(lambda x: 1 if x == "Submit Downgrade" else 2 if x=="Cancellation Confirmation" else 0, IntegerType())
In [26]:
user_log = user_log\
    .withColumn("upgrade", flag_Cancellation_event("page"))\
    .withColumn("Churn", Fsum("upgrade").over(windowval))

user_log.head()
Out[26]:
Row(artist=None, auth='Logged In', firstName='Darianna', gender='F', itemInSession=34, lastName='Carpenter', length=None, level='free', location='Bridgeport-Stamford-Norwalk, CT', method='PUT', page='Logout', registration=1538016340000, sessionId=187, song=None, status=307, ts=1542823952000, userAgent='"Mozilla/5.0 (iPhone; CPU iPhone OS 7_1_2 like Mac OS X) AppleWebKit/537.51.2 (KHTML, like Gecko) Version/7.0 Mobile/11D257 Safari/9537.53"', userId='100010', upgrade=0, Churn=0)
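The window logic above can be checked on toy data: sorting each user's events by descending `ts` and taking a running sum of the cancellation flag marks every event of a churned user with `Churn = 1`. A pandas sketch of the same idea (toy events assumed, not the notebook's data):

```python
import pandas as pd

events = pd.DataFrame({
    "userId": ["18", "18", "18", "30", "30"],
    "ts":     [3, 2, 1, 2, 1],
    "page":   ["Cancellation Confirmation", "NextSong", "Home",
               "NextSong", "Home"],
})

# Flag cancellation events, then accumulate per user over descending ts,
# mirroring Window.partitionBy("userId").orderBy(desc("ts")).
events["flag"] = (events["page"] == "Cancellation Confirmation").astype(int)
events = events.sort_values(["userId", "ts"], ascending=[True, False])
events["Churn"] = events.groupby("userId")["flag"].cumsum()
print(events[["userId", "ts", "page", "Churn"]])
```

Because the cancellation is always a user's last event, summing backwards from it labels the user's entire history, which is exactly what the model needs.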
In [27]:
# add day/month columns: convert the millisecond epoch timestamp to date strings
get_day = udf(lambda x: datetime.datetime.fromtimestamp(x/1000.0).strftime('%Y%m%d'))
get_month = udf(lambda x: datetime.datetime.fromtimestamp(x/1000.0).strftime('%Y%m'))
user_log = user_log.withColumn("day", get_day('ts'))
user_log = user_log.withColumn("month", get_month('ts'))
user_log = user_log.withColumn("week", weekofyear(from_unixtime(user_log.ts / 1000.0)))
user_log.show(2)
+----------+---------+---------+------+-------------+---------+---------+-----+--------------------+------+--------+-------------+---------+----+------+-------------+--------------------+------+-------+-----+--------+------+----+
|    artist|     auth|firstName|gender|itemInSession| lastName|   length|level|            location|method|    page| registration|sessionId|song|status|           ts|           userAgent|userId|upgrade|Churn|     day| month|week|
+----------+---------+---------+------+-------------+---------+---------+-----+--------------------+------+--------+-------------+---------+----+------+-------------+--------------------+------+-------+-----+--------+------+----+
|      null|Logged In| Darianna|     F|           34|Carpenter|     null| free|Bridgeport-Stamfo...|   PUT|  Logout|1538016340000|      187|null|   307|1542823952000|"Mozilla/5.0 (iPh...|100010|      0|    0|20181121|201811|  47|
|Lily Allen|Logged In| Darianna|     F|           33|Carpenter|185.25995| free|Bridgeport-Stamfo...|   PUT|NextSong|1538016340000|      187|  22|   200|1542823951000|"Mozilla/5.0 (iPh...|100010|      0|    0|20181121|201811|  47|
+----------+---------+---------+------+-------------+---------+---------+-----+--------------------+------+--------+-------------+---------+----+------+-------------+--------------------+------+-------+-----+--------+------+----+
only showing top 2 rows
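
The conversion those UDFs perform can be sketched in plain Python. This version uses an explicit UTC timezone so the result is deterministic (the notebook's `fromtimestamp` uses the machine's local zone, an assumed simplification here):

```python
import datetime

def ms_to_day(ms):
    """Convert a millisecond epoch timestamp to a 'YYYYMMDD' string (UTC)."""
    return datetime.datetime.fromtimestamp(
        ms / 1000.0, tz=datetime.timezone.utc).strftime("%Y%m%d")

print(ms_to_day(1538352117000))  # → '20181001' (first event in the sample)
```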

The log covers only about two months:

In [28]:
user_log.agg({'day':'max'}).show()
user_log.agg({'day':'min'}).show()
+--------+
|max(day)|
+--------+
|20181203|
+--------+

+--------+
|min(day)|
+--------+
|20181001|
+--------+

The following histogram shows that the classes are imbalanced:

In [29]:
user_Churn=user_log\
    .select(['userId','Churn'])\
    .dropDuplicates(['userId'])

plot_hist(df=user_Churn,param='Churn',title="Churn  ",x_axis="Churn",y_axis="number of user ") 
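
Beyond the histogram, the imbalance can be quantified as the share of churned users. A quick pandas sketch with assumed toy labels (the real notebook would compute this from `user_Churn`):

```python
import pandas as pd

labels = pd.DataFrame({"userId": ["1", "2", "3", "4"],
                       "Churn":  [1, 0, 0, 0]})

# Mean of a 0/1 label is the positive-class rate.
churn_rate = labels["Churn"].mean()
print(f"churn rate: {churn_rate:.2f}")  # 0.25 for this toy frame
```

A minority-class rate well below 0.5 is why accuracy alone is a poor metric here and F1 or AUC is preferable for model evaluation.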

Explore Data

Select candidate explanatory variables and transform them into per-user features that plausibly affect the target variable.

In [30]:
user_log.filter(user_log.userId=='143').orderBy('ts').show(3)
+--------------------+---------+---------+------+-------------+--------+---------+-----+--------------------+------+--------+-------------+---------+--------------------+------+-------------+--------------------+------+-------+-----+--------+------+----+
|              artist|     auth|firstName|gender|itemInSession|lastName|   length|level|            location|method|    page| registration|sessionId|                song|status|           ts|           userAgent|userId|upgrade|Churn|     day| month|week|
+--------------------+---------+---------+------+-------------+--------+---------+-----+--------------------+------+--------+-------------+---------+--------------------+------+-------------+--------------------+------+-------+-----+--------+------+----+
|              P.O.D.|Logged In|    Molly|     F|            0|Harrison|225.88036| free|Virginia Beach-No...|   PUT|NextSong|1534255113000|      142|Will You (Single ...|   200|1538401635000|"Mozilla/5.0 (Mac...|   143|      0|    1|20181001|201810|  40|
|           The Faint|Logged In|    Molly|     F|            1|Harrison|178.23302| free|Virginia Beach-No...|   PUT|NextSong|1534255113000|      142|The Geeks Were Right|   200|1538401860000|"Mozilla/5.0 (Mac...|   143|      0|    1|20181001|201810|  40|
|Kid Cudi / MGMT /...|Logged In|    Molly|     F|            2|Harrison|295.67955| free|Virginia Beach-No...|   PUT|NextSong|1534255113000|      142|Pursuit Of Happin...|   200|1538402038000|"Mozilla/5.0 (Mac...|   143|      0|    1|20181001|201810|  40|
+--------------------+---------+---------+------+-------------+--------+---------+-----+--------------------+------+--------+-------------+---------+--------------------+------+-------------+--------------------+------+-------+-----+--------+------+----+
only showing top 3 rows

gender

In [31]:
gender=user_log\
    .select(['userId','gender','Churn'])\
    .dropDuplicates(['userId'])\
    .replace(['F','M'],['0','1'],'gender')
gender = gender.withColumn('gender', gender.gender.cast('int'))


gender.show(10)

plot_hist(df=gender,param='gender',title="gender ",x_axis="gender",y_axis="number of user ") 
+------+------+-----+
|userId|gender|Churn|
+------+------+-----+
|100010|     0|    0|
|200002|     1|    0|
|   125|     1|    1|
|   124|     0|    0|
|    51|     1|    1|
|     7|     1|    0|
|    15|     1|    0|
|    54|     0|    1|
|   155|     0|    0|
|100014|     1|    1|
+------+------+-----+
only showing top 10 rows

level (Free or Paid)

In [32]:
level=user_log\
    .select(['userId','level','Churn'])\
    .dropDuplicates(['userId'])\
    .replace(['paid','free'],['0','1'],'level')
level= level.withColumn('level', level.level.cast('int'))

level.show(10)

plot_hist(df=level,param='level',title="level ",x_axis="level",y_axis="number of user ") 
+------+-----+-----+
|userId|level|Churn|
+------+-----+-----+
|100010|    1|    0|
|200002|    0|    0|
|   125|    1|    1|
|   124|    0|    0|
|    51|    0|    1|
|     7|    1|    0|
|    15|    0|    0|
|    54|    0|    1|
|   155|    0|    0|
|100014|    0|    1|
+------+-----+-----+
only showing top 10 rows
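
The replace-then-cast pattern used for `gender` and `level` has a direct pandas equivalent: map the category strings to integer labels. A sketch on assumed toy data:

```python
import pandas as pd

users = pd.DataFrame({"userId": ["10", "11"], "level": ["paid", "free"]})

# Same encoding as the Spark replace(['paid','free'], ['0','1']) + cast('int').
users["level"] = users["level"].map({"paid": 0, "free": 1}).astype(int)
print(users)
```

Note that `dropDuplicates(['userId'])` keeps one arbitrary row per user, so a user who switched between free and paid is encoded by whichever level happened to survive; a last-observed level per user would be a more deliberate choice.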

average songs per session

In [33]:
#user_log.groupby("userId","sessionId").count().orderBy("userId",user_log.day.cast("float")).groupby("userId").avg().orderBy("userId")
session_song=user_log\
    .filter(user_log.page == 'NextSong')\
    .select('page','userId','sessionId')\
    .groupby("userId","sessionId")\
    .agg({'page':'count'})\
    .withColumnRenamed('count(page)', 'count_song') \
    .sort(('userId'))
session_song.show(10)

session_song=session_song.groupby("userId")\
.agg({'count_song':'avg'})\
.withColumnRenamed('avg(count_song)', 'songs_per_session')



session_song=session_song\
    .join(user_Churn,on='userId',how="inner")
session_song.show(10)

plot_hist(df=session_song,param='songs_per_session',title="songs per session ",y_axis="number of user",x_axis="song count per session")    
+------+---------+----------+
|userId|sessionId|count_song|
+------+---------+----------+
|    10|      595|       381|
|    10|        9|        57|
|    10|     1414|        63|
|    10|     1047|        21|
|    10|     1592|        67|
|    10|     1981|        84|
|   100|      683|        23|
|   100|     1782|         4|
|   100|      369|        91|
|   100|     2428|        81|
+------+---------+----------+
only showing top 10 rows

+------+------------------+-----+
|userId| songs_per_session|Churn|
+------+------------------+-----+
|100010|39.285714285714285|    0|
|200002|              64.5|    0|
|   125|               8.0|    1|
|   124|145.67857142857142|    0|
|    51|             211.1|    1|
|     7|21.428571428571427|    0|
|    15|136.71428571428572|    0|
|    54| 81.17142857142858|    1|
|   155|136.66666666666666|    0|
|100014|42.833333333333336|    1|
+------+------------------+-----+
only showing top 10 rows

average songs per day

In [34]:
songs_in_day = user_log\
    .groupby("userId","day")\
    .count()\
    .orderBy("userId",user_log.day.cast("float"))
songs_in_day.show(n=10)

songs_in_day=songs_in_day.groupby("userId")\
    .agg({'count':'avg'})\
    .withColumnRenamed('avg(count)', 'songs_per_day') \
    .orderBy('userId')
#songs_in_day.show(n=10)

songs_in_day=songs_in_day\
    .join(user_Churn,on='userId',how="inner")
songs_in_day.show(10)

plot_hist(df=songs_in_day,param='songs_per_day',title="songs per day ",y_axis="number of user",x_axis="songs per day")  
+------+--------+-----+
|userId|     day|count|
+------+--------+-----+
|    10|20181008|   70|
|    10|20181018|  345|
|    10|20181019|  104|
|    10|20181029|   25|
|    10|20181103|   78|
|    10|20181116|    8|
|    10|20181115|   70|
|    10|20181119|   95|
|   100|20181002|  154|
|   100|20181004|  108|
+------+--------+-----+
only showing top 10 rows

+------+------------------+-----+
|userId|     songs_per_day|Churn|
+------+------------------+-----+
|100010| 54.42857142857143|    0|
|200002| 67.71428571428571|    0|
|   125|              11.0|    1|
|   124|146.21212121212122|    0|
|    51|189.53846153846155|    1|
|     7|            25.125|    0|
|    15|119.89473684210526|    0|
|    54|110.87096774193549|    1|
|   155|            125.25|    0|
|100014|51.666666666666664|    1|
+------+------------------+-----+
only showing top 10 rows

stddev songs per day

In [35]:
std_songs_in_day = user_log\
    .groupby("userId","day")\
    .count()\
    .orderBy("userId",user_log.day.cast("float"))
std_songs_in_day.show(n=10)

std_songs_in_day=std_songs_in_day.groupby("userId")\
    .agg({'count':'stddev_pop'})\
    .withColumnRenamed('stddev_pop(count)', 'stddev_songs_per_day') \
    .orderBy('userId')
#songs_in_day.show(n=10)

std_songs_in_day=std_songs_in_day\
    .join(user_Churn,on='userId',how="inner")
std_songs_in_day.show(10)

plot_hist(df=std_songs_in_day,param='stddev_songs_per_day',title="stddev songs per day",y_axis="number of user",x_axis="stddev songs per day") 
+------+--------+-----+
|userId|     day|count|
+------+--------+-----+
|    10|20181008|   70|
|    10|20181018|  345|
|    10|20181019|  104|
|    10|20181029|   25|
|    10|20181103|   78|
|    10|20181116|    8|
|    10|20181115|   70|
|    10|20181119|   95|
|   100|20181002|  154|
|   100|20181004|  108|
+------+--------+-----+
only showing top 10 rows

+------+--------------------+-----+
|userId|stddev_songs_per_day|Churn|
+------+--------------------+-----+
|100010|  30.203392159808534|    0|
|200002|   53.63558063640294|    0|
|   125|                 0.0|    1|
|   124|  105.56852078318111|    0|
|    51|   115.2379580791352|    1|
|     7|  27.154361988454085|    0|
|    15|  62.535795843443566|    0|
|    54|   82.29950885107391|    1|
|   155|   76.70519865041743|    0|
|100014|  29.244182707373138|    1|
+------+--------------------+-----+
only showing top 10 rows

average songs per week

In [36]:
songs_in_week = user_log\
    .groupby("userId","week")\
    .count()\
    .orderBy("userId",user_log.week.cast("float"))
songs_in_week.show(n=10)

songs_in_week=songs_in_week.groupby("userId")\
    .agg({'count':'avg'})\
    .withColumnRenamed('avg(count)', 'songs_per_week') \
    .orderBy('userId')
#songs_in_day.show(n=10)

songs_in_week=songs_in_week\
    .join(user_Churn,on='userId',how="inner")
songs_in_week.show(10)

plot_hist(df=songs_in_week,param='songs_per_week',title="songs per week ",y_axis="number of user",x_axis="songs per week") 
+------+----+-----+
|userId|week|count|
+------+----+-----+
|    10|  41|   70|
|    10|  42|  449|
|    10|  44|  103|
|    10|  46|   78|
|    10|  47|   95|
|   100|  40|  302|
|   100|  41|  319|
|   100|  42|  240|
|   100|  43|  482|
|   100|  44|  620|
+------+----+-----+
only showing top 10 rows

+------+------------------+-----+
|userId|    songs_per_week|Churn|
+------+------------------+-----+
|100010|             95.25|    0|
|200002|             118.5|    0|
|   125|              11.0|    1|
|   124| 536.1111111111111|    0|
|    51| 821.3333333333334|    1|
|     7|              40.2|    0|
|    15|325.42857142857144|    0|
|    54|             491.0|    1|
|   155|             250.5|    0|
|100014|              77.5|    1|
+------+------------------+-----+
only showing top 10 rows

stddev songs per week

In [37]:
std_songs_in_week = user_log\
    .groupby("userId","week")\
    .count()\
    .orderBy("userId",user_log.week.cast("float"))
std_songs_in_week.show(n=10)

std_songs_in_week=std_songs_in_week.groupby("userId")\
    .agg({'count':'stddev_pop'})\
    .withColumnRenamed('stddev_pop(count)', 'stddev_songs_per_week') \
    .orderBy('userId')
#songs_in_day.show(n=10)

std_songs_in_week=std_songs_in_week\
    .join(user_Churn,on='userId',how="inner")
std_songs_in_week.show(10)

plot_hist(df=std_songs_in_week,param='stddev_songs_per_week',title="stddev_songs per week ",y_axis="number of user",x_axis="stddev_songs per week") 
+------+----+-----+
|userId|week|count|
+------+----+-----+
|    10|  41|   70|
|    10|  42|  449|
|    10|  44|  103|
|    10|  46|   78|
|    10|  47|   95|
|   100|  40|  302|
|   100|  41|  319|
|   100|  42|  240|
|   100|  43|  482|
|   100|  44|  620|
+------+----+-----+
only showing top 10 rows

+------+---------------------+-----+
|userId|stddev_songs_per_week|Churn|
+------+---------------------+-----+
|100010|    62.51149894219463|    0|
|200002|   57.107355042936454|    0|
|   125|                  0.0|    1|
|   124|    314.4563013076678|    0|
|    51|    309.1432605695223|    1|
|     7|   26.483202223296185|    0|
|    15|   152.28343605907762|    0|
|    54|    300.8825114796234|    1|
|   155|    97.32805350976665|    0|
|100014|    51.27133702177075|    1|
+------+---------------------+-----+
only showing top 10 rows

average songs per month

In [38]:
songs_in_month = user_log\
    .groupby("userId","month")\
    .count()\
    .orderBy("userId",user_log.month.cast("float"))
songs_in_month.show(n=10)

songs_in_month=songs_in_month.groupby("userId")\
    .agg({'count':'avg'})\
    .withColumnRenamed('avg(count)', 'songs_per_month') \
    .orderBy('userId')
#songs_in_day.show(n=10)

songs_in_month=songs_in_month\
    .join(user_Churn,on='userId',how="inner")
songs_in_month.show(10)

plot_hist(df=songs_in_month,param='songs_per_month',title="songs per month ",y_axis="number of user",x_axis="songs per month") 
+------+------+-----+
|userId| month|count|
+------+------+-----+
|    10|201810|  544|
|    10|201811|  251|
|   100|201810| 1462|
|   100|201811| 1752|
|100001|201810|  187|
|100002|201810|    5|
|100002|201811|  212|
|100002|201812|    1|
|100003|201810|   78|
|100004|201810|  557|
+------+------+-----+
only showing top 10 rows

+------+---------------+-----+
|userId|songs_per_month|Churn|
+------+---------------+-----+
|100010|          190.5|    0|
|200002|          237.0|    0|
|   125|           11.0|    1|
|   124|         2412.5|    0|
|    51|         2464.0|    1|
|     7|          100.5|    0|
|    15|         1139.0|    0|
|    54|         1718.5|    1|
|   155|         1002.0|    0|
|100014|          155.0|    1|
+------+---------------+-----+
only showing top 10 rows

stddev songs per month

In [39]:
std_songs_in_month = user_log\
    .groupby("userId","month")\
    .count()\
    .orderBy("userId",user_log.month.cast("float"))
std_songs_in_month.show(n=10)

std_songs_in_month=std_songs_in_month.groupby("userId")\
    .agg({'count':'stddev_pop'})\
    .withColumnRenamed('stddev_pop(count)', 'stddev_songs_per_month') \
    .orderBy('userId')
#songs_in_day.show(n=10)

std_songs_in_month=std_songs_in_month\
    .join(user_Churn,on='userId',how="inner")
std_songs_in_month.show(10)

plot_hist(df=std_songs_in_month,param='stddev_songs_per_month',title="stddev_songs per month ",y_axis="number of user",x_axis="stddev_songs per month") 
+------+------+-----+
|userId| month|count|
+------+------+-----+
|    10|201810|  544|
|    10|201811|  251|
|   100|201810| 1462|
|   100|201811| 1752|
|100001|201810|  187|
|100002|201810|    5|
|100002|201811|  212|
|100002|201812|    1|
|100003|201810|   78|
|100004|201810|  557|
+------+------+-----+
only showing top 10 rows

+------+----------------------+-----+
|userId|stddev_songs_per_month|Churn|
+------+----------------------+-----+
|100010|                  26.5|    0|
|200002|                  97.0|    0|
|   125|                   0.0|    1|
|   124|                 205.5|    0|
|    51|                   0.0|    1|
|     7|                  43.5|    0|
|    15|                 194.0|    0|
|    54|                 758.5|    1|
|   155|                   0.0|    0|
|100014|                  79.0|    1|
+------+----------------------+-----+
only showing top 10 rows

Number of songs heard so far

In [40]:
songs=user_log\
    .groupby("userId")\
    .agg({'song':'count'})\
    .withColumnRenamed('count(song)', 'songs heard so far') \
    .orderBy("userId")

songs=songs\
    .join(user_Churn,on='userId',how="inner")
songs.show(10)

plot_hist(df=songs,param='songs heard so far',title="Number of songs heard so far ",y_axis="number of user",x_axis="songs ") 
+------+------------------+-----+
|userId|songs heard so far|Churn|
+------+------------------+-----+
|100010|               275|    0|
|200002|               387|    0|
|   125|                 8|    1|
|   124|              4079|    0|
|    51|              2111|    1|
|     7|               150|    0|
|    15|              1914|    0|
|    54|              2841|    1|
|   155|               820|    0|
|100014|               257|    1|
+------+------------------+-----+
only showing top 10 rows

Number of Thumbs Up / Down

In [41]:
Thumbs_Down=user_log\
    .filter(user_log.page == 'Thumbs Down')\
    .select('page','userId')\
    .groupby("userId")\
    .agg({'page':'count'})\
    .withColumnRenamed('count(page)', 'Thumbs Down') \
    .orderBy("userId")

Thumbs_Down=Thumbs_Down\
    .join(user_Churn,on='userId',how="inner")
Thumbs_Down.show(n=10)

plot_hist(df=Thumbs_Down,param='Thumbs Down',title="Number of Thumbs Down ",y_axis="number of user",x_axis="Thumbs Down") 

Thumbs_Up=user_log\
    .filter(user_log.page == 'Thumbs Up')\
    .select('page','userId')\
    .groupby("userId")\
    .agg({'page':'count'})\
    .withColumnRenamed('count(page)', 'Thumbs Up') \
    .orderBy("userId")

Thumbs_Up=Thumbs_Up\
    .join(user_Churn,on='userId',how="inner")
Thumbs_Up.show(10)

plot_hist(df=Thumbs_Up,param='Thumbs Up',title="Number of Thumbs Up ",y_axis="number of user",x_axis="Thumbs Up") 
+------+-----------+-----+
|userId|Thumbs Down|Churn|
+------+-----------+-----+
|100010|          5|    0|
|200002|          6|    0|
|   124|         41|    0|
|    51|         21|    1|
|     7|          1|    0|
|    15|         14|    0|
|    54|         29|    1|
|   155|          3|    0|
|100014|          3|    1|
|   132|         17|    0|
+------+-----------+-----+
only showing top 10 rows

+------+---------+-----+
|userId|Thumbs Up|Churn|
+------+---------+-----+
|100010|       17|    0|
|200002|       21|    0|
|   124|      171|    0|
|    51|      100|    1|
|     7|        7|    0|
|    15|       81|    0|
|    54|      163|    1|
|   155|       58|    0|
|100014|       17|    1|
|   132|       96|    0|
+------+---------+-----+
only showing top 10 rows

Number of Add to Playlist events

In [42]:
playlist=user_log\
    .filter(user_log.page == 'Add to Playlist')\
    .select('page','userId')\
    .groupby("userId")\
    .agg({'page':'count'})\
    .withColumnRenamed('count(page)', 'playlist') \
    .orderBy("userId")

playlist=playlist\
    .join(user_Churn,on='userId',how="inner")
playlist.show(10)
plot_hist(df=playlist,param='playlist',title="Number of playlist ",y_axis="number of user",x_axis="Playlist") 
+------+--------+-----+
|userId|playlist|Churn|
+------+--------+-----+
|100010|       7|    0|
|200002|       8|    0|
|   124|     118|    0|
|    51|      52|    1|
|     7|       5|    0|
|    15|      59|    0|
|    54|      72|    1|
|   155|      24|    0|
|100014|       7|    1|
|   132|      38|    0|
+------+--------+-----+
only showing top 10 rows

Number of Add Friend events

In [43]:
Friend=user_log\
    .filter(user_log.page == 'Add Friend')\
    .select('page','userId')\
    .groupby("userId")\
    .agg({'page':'count'})\
    .withColumnRenamed('count(page)', 'Friend') \
    .orderBy("userId")
Friend=Friend\
    .join(user_Churn,on='userId',how="inner")
Friend.show(10)
plot_hist(df=Friend,param='Friend',title="Number of Friend ",y_axis="number of user",x_axis="Friend") 
+------+------+-----+
|userId|Friend|Churn|
+------+------+-----+
|100010|     4|    0|
|200002|     4|    0|
|   124|    74|    0|
|    51|    28|    1|
|     7|     1|    0|
|    15|    31|    0|
|    54|    33|    1|
|   155|    11|    0|
|100014|     6|    1|
|   132|    41|    0|
+------+------+-----+
only showing top 10 rows

Total length heard so far

In [44]:
length=user_log\
    .select('length','userId')\
    .groupby("userId")\
    .agg({'length':'sum'})\
    .withColumnRenamed('sum(length)', 'total_lenght') \
    .orderBy("userId")
length=length\
    .join(user_Churn,on='userId',how="inner")
length.show(10)
plot_hist(df=length,param='total_lenght',title="total_lenght",y_axis="number of user",x_axis="total_lenght") 
+------+------------------+-----+
|userId|      total_lenght|Churn|
+------+------------------+-----+
|100010|       66940.89735|    0|
|200002| 94008.87593999997|    0|
|   125|2089.1131000000005|    1|
|   124|1012312.0927900004|    0|
|    51| 523275.8428000001|    1|
|     7|        38034.0871|    0|
|    15|477307.60580999986|    0|
|    54| 711344.9195400004|    1|
|   155|198779.29190000004|    0|
|100014|       67703.47208|    1|
+------+------------------+-----+
only showing top 10 rows

Days from registration

In [45]:
df_s=user_log\
    .select('ts','userId','registration')\
    .groupby("userId")\
    .agg({'ts':'max','registration':'max'})\
    .withColumnRenamed('max(ts)', 'ts') \
    .withColumnRenamed('max(registration)', 'regi') \
    .orderBy("userId")
#df_s.show(2)

regi_length = df_s.select(
    col('userId'),
    datediff(from_unixtime(col('ts')/1000),
             from_unixtime(col('regi')/1000)).alias('Day from registration'))

regi_length=regi_length\
    .join(user_Churn,on='userId',how="inner")
regi_length.show(10)
plot_hist(df=regi_length,param='Day from registration',title="Days from registration",y_axis="number of user",x_axis="Days from registration") 

Feature Engineering

Once you've familiarized yourself with the data, build out the features you find promising to train your model on. To work with the full dataset, you can follow the following steps.

  • Write a script to extract the necessary features from the smaller subset of data
  • Ensure that your script is scalable, using the best practices discussed in Lesson 3
  • Try your script on the full data set, debugging your script if necessary

If you are working in the classroom workspace, you can just extract features based on the small subset of data contained here. Be sure to transfer over this work to the larger dataset when you work on your Spark cluster.

result (join feature)

The original time-series log has been aggregated into one row of features per user. Here we join the features extracted above on userId and rename the objective variable "Churn" to "target".

In [46]:
result=session_song.select('userId','Churn','songs_per_session')\
    .join(gender.drop('Churn'),on='userId',how="inner")\
    .join(level.drop('Churn'),on='userId',how="inner")\
    .join(songs_in_day.drop('Churn'),on='userId',how="inner")\
    .join(std_songs_in_day.drop('Churn'),on='userId',how="inner")\
    .join(songs_in_week.drop('Churn'),on='userId',how="inner")\
    .join(std_songs_in_week.drop('Churn'),on='userId',how="inner")\
    .join(songs_in_month.drop('Churn'),on='userId',how="inner")\
    .join(std_songs_in_month.drop('Churn'),on='userId',how="inner")\
    .join(songs.drop('Churn'),on='userId',how="inner")\
    .join(Thumbs_Up.drop('Churn'),on='userId',how="inner")\
    .join(Thumbs_Down.drop('Churn'),on='userId',how="inner")\
    .join(playlist.drop('Churn'),on='userId',how="inner")\
    .join(Friend.drop('Churn'),on='userId',how="inner")\
    .join(length.drop('Churn'),on='userId',how="inner")\
    .join(regi_length.drop('Churn'),on='userId',how="inner")\
    .withColumnRenamed("Churn","target")
In [47]:
result.printSchema()
result.show(4)
root
 |-- userId: string (nullable = true)
 |-- target: long (nullable = true)
 |-- songs_per_session: double (nullable = true)
 |-- gender: integer (nullable = true)
 |-- level: integer (nullable = true)
 |-- songs_per_day: double (nullable = true)
 |-- stddev_songs_per_day: double (nullable = true)
 |-- songs_per_week: double (nullable = true)
 |-- stddev_songs_per_week: double (nullable = true)
 |-- songs_per_month: double (nullable = true)
 |-- stddev_songs_per_month: double (nullable = true)
 |-- songs heard so far: long (nullable = false)
 |-- Thumbs Up: long (nullable = false)
 |-- Thumbs Down: long (nullable = false)
 |-- playlist: long (nullable = false)
 |-- Friend: long (nullable = false)
 |-- total_lenght: double (nullable = true)
 |-- Day from registration: integer (nullable = true)
+------+------+------------------+------+-----+------------------+--------------------+-----------------+---------------------+---------------+----------------------+------------------+---------+-----------+--------+------+------------------+---------------------+
|userId|target| songs_per_session|gender|level|     songs_per_day|stddev_songs_per_day|   songs_per_week|stddev_songs_per_week|songs_per_month|stddev_songs_per_month|songs heard so far|Thumbs Up|Thumbs Down|playlist|Friend|      total_lenght|Day from registration|
+------+------+------------------+------+-----+------------------+--------------------+-----------------+---------------------+---------------+----------------------+------------------+---------+-----------+--------+------+------------------+---------------------+
|100010|     0|39.285714285714285|     0|    1| 54.42857142857143|  30.203392159808534|            95.25|    62.51149894219463|          190.5|                  26.5|               275|       17|          5|       7|     4|       66940.89735|                   55|
|200002|     0|              64.5|     1|    1| 67.71428571428571|   53.63558063640294|            118.5|   57.107355042936454|          237.0|                  97.0|               387|       21|          6|       8|     4| 94008.87593999997|                   70|
|   124|     0|145.67857142857142|     0|    0|146.21212121212122|  105.56852078318111|536.1111111111111|    314.4563013076678|         2412.5|                 205.5|              4079|      171|         41|     118|    74|1012312.0927900004|                  131|
|    51|     1|             211.1|     1|    0|189.53846153846155|   115.2379580791352|821.3333333333334|    309.1432605695223|         2464.0|                   0.0|              2111|      100|         21|      52|    28| 523275.8428000001|                   20|
+------+------+------------------+------+-----+------------------+--------------------+-----------------+---------------------+---------------+----------------------+------------------+---------+-----------+--------+------+------------------+---------------------+
only showing top 4 rows

Modeling

Split the full dataset into train, test, and validation sets. Test out several of the machine learning methods you learned. Evaluate the accuracy of the various models, tuning parameters as necessary. Determine your winning model based on test accuracy and report results on the validation set. Since the churned users are a fairly small subset, I suggest using F1 score as the metric to optimize.

In [48]:
result=result.drop('userId')
In [49]:
col = result.columns  # note: this shadows pyspark.sql.functions.col imported earlier
col.remove('target')
col
Out[49]:
['songs_per_session',
 'gender',
 'level',
 'songs_per_day',
 'stddev_songs_per_day',
 'songs_per_week',
 'stddev_songs_per_week',
 'songs_per_month',
 'stddev_songs_per_month',
 'songs heard so far',
 'Thumbs Up',
 'Thumbs Down',
 'playlist',
 'Friend',
 'total_lenght',
 'Day from registration']
In [50]:
train, test = result.randomSplit([0.7, 0.3], seed=0)
In [51]:
from pyspark.ml.feature import VectorAssembler

cols = result.columns  # 'target' is kept so its correlations appear in the matrix
assembler = VectorAssembler(inputCols=cols, outputCol="features")
output = assembler.transform(result)
sdf2 = output.select("features")

from pyspark.ml.stat import Correlation

s_corr = Correlation.corr(sdf2, "features", "pearson").head()
#print("Pearson correlation matrix:\n" + str(s_corr[0]))

s_corr_ls = s_corr[0].toArray().tolist()
s_corr_df = spark.createDataFrame(s_corr_ls, cols)
p_corr_df = s_corr_df.select("*").toPandas()
r_index = pd.Series(cols)
p_corr_df = p_corr_df.set_index(r_index)
#print(p_corr_df)

import seaborn as sns
cm = sns.clustermap(p_corr_df, annot=True, figsize=(8, 8), col_cluster=False, row_cluster=False, fmt="1.1f")
cm.cax.set_visible(False)
display()

model

model pipeline

In [52]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.tuning import CrossValidator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

def build_model(classifier, param):
    """
    Pipeline of assembler, scaler and CrossValidator.

    input:
        classifier : estimator for CrossValidator
        param      : parameter grid for the estimator
    output:
        pipeline model
    """
    assembler = VectorAssembler(inputCols=col, outputCol="features")
    scaler = StandardScaler(inputCol="features", outputCol="scaled_features")

    cv = CrossValidator(
        estimator=classifier,
        estimatorParamMaps=param,
        evaluator=MulticlassClassificationEvaluator(labelCol='target', metricName='f1'),
        numFolds=5,
    )
    model = Pipeline(stages=[assembler, scaler, cv])

    return model

metrics

Because this is an imbalanced dataset, the F1 score is the most suitable evaluation metric.
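To see why accuracy alone is misleading on imbalanced data, a small plain-Python illustration (synthetic numbers, not from the Sparkify log):

```python
# Suppose 80 of 100 users did not churn. A trivial model that always
# predicts "no churn" still reaches 80% accuracy ...
labels = [0] * 80 + [1] * 20          # 0 = stayed, 1 = churned
preds  = [0] * 100                    # always predict "no churn"

accuracy = sum(p == y for p, y in zip(preds, labels)) / len(labels)
print(accuracy)  # 0.8

# ... but it never finds a single churned user: churn-class recall is 0,
# so the churn-class F1 score is 0 as well.
true_pos = sum(p == 1 and y == 1 for p, y in zip(preds, labels))
recall = true_pos / sum(labels)
print(recall)  # 0.0
```

The F1 score punishes exactly this failure mode, which is why it drives model selection below.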

In [53]:
def metrics(pred):
    """
    print  F1 score,Accuracy,Precision and Recall
    
    input
         pred: prediction after fit and transform

    
    """
    
    evaluator  = MulticlassClassificationEvaluator(labelCol="target", predictionCol="prediction" )
    precision =evaluator.evaluate(pred, {evaluator.metricName:"weightedPrecision"})
    recall =evaluator.evaluate(pred, {evaluator.metricName: "weightedRecall"})
    f1 =evaluator.evaluate(pred, {evaluator.metricName: "f1"})
    accuracy =evaluator.evaluate(pred, {evaluator.metricName: "accuracy"})
    print("F1       : {}".format(f1))
    print("Accuracy : {}".format(accuracy))
    print("Precision: {}".format(precision))
    print("Recall   : {}".format(recall))

Logistic Regression

In [54]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.tuning import ParamGridBuilder

lr = LogisticRegression(featuresCol="scaled_features", labelCol="target")
param = ParamGridBuilder().build()
model = build_model(lr, param)


fit_model = model.fit(train)
pred = fit_model.transform(test)
#pred.select("prediction").dropDuplicates().collect()
In [55]:
best_model_lr=fit_model.stages[-1].bestModel
In [56]:
feature_coef = best_model_lr.coefficients
feature_coef_df = pd.DataFrame(list(zip(col, feature_coef)), columns=['Feature', 'Coefficient']).sort_values('Coefficient', ascending=False)
feature_coef_df.plot(kind='barh',x='Feature',y='Coefficient',legend=False)
plt.title('Feature importance for Logistic Regression')
plt.xlabel('Coefficient')
Out[56]:
Text(0.5, 0, 'Coefficient')
In [60]:
metrics(pred)
F1       : 0.8513017356475301
Accuracy : 0.859375
Precision: 0.8488636363636364
Recall   : 0.859375

GBTClassifier

Perform hyperparameter tuning over the following grid:

  • maxDepth: 2, 6
  • maxBins: 10, 20
  • maxIter: 5, 10

As a result of tuning, maxDepth=2, maxBins=10 and maxIter=5 were selected.
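The 2×2×2 grid enumerates eight candidate settings, which can be sketched in plain Python (ParamGridBuilder builds the equivalent cross product internally):

```python
from itertools import product

# Cross product of the three tuned hyperparameters: 2 * 2 * 2 = 8 candidates
names = ("maxDepth", "maxBins", "maxIter")
grid = [dict(zip(names, values))
        for values in product([2, 6], [10, 20], [5, 10])]

print(len(grid))   # 8
print(grid[0])     # {'maxDepth': 2, 'maxBins': 10, 'maxIter': 5}
```

With 5-fold cross-validation this means 8 × 5 = 40 model fits, which is worth keeping in mind before scaling the grid up on the full dataset.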

In [65]:
from pyspark.ml.classification import GBTClassifier

gbt = GBTClassifier(labelCol="target", featuresCol="scaled_features")
param = ParamGridBuilder().addGrid(gbt.maxDepth, [2,6]).addGrid(gbt.maxBins, [10,20]).addGrid(gbt.maxIter, [5,10]).build()
print(param)
model = build_model(gbt, param)

gbt_model = model.fit(train)
pred_gbt = gbt_model.transform(test)
[{Param(parent='GBTClassifier_5b48f91c6e61', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.'): 2, Param(parent='GBTClassifier_5b48f91c6e61', name='maxBins', doc='Max number of bins for discretizing continuous features.  Must be >=2 and >= number of categories for any categorical feature.'): 10, Param(parent='GBTClassifier_5b48f91c6e61', name='maxIter', doc='max number of iterations (>= 0).'): 5}, {Param(parent='GBTClassifier_5b48f91c6e61', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.'): 2, Param(parent='GBTClassifier_5b48f91c6e61', name='maxBins', doc='Max number of bins for discretizing continuous features.  Must be >=2 and >= number of categories for any categorical feature.'): 10, Param(parent='GBTClassifier_5b48f91c6e61', name='maxIter', doc='max number of iterations (>= 0).'): 10}, {Param(parent='GBTClassifier_5b48f91c6e61', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.'): 2, Param(parent='GBTClassifier_5b48f91c6e61', name='maxBins', doc='Max number of bins for discretizing continuous features.  Must be >=2 and >= number of categories for any categorical feature.'): 20, Param(parent='GBTClassifier_5b48f91c6e61', name='maxIter', doc='max number of iterations (>= 0).'): 5}, {Param(parent='GBTClassifier_5b48f91c6e61', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.'): 2, Param(parent='GBTClassifier_5b48f91c6e61', name='maxBins', doc='Max number of bins for discretizing continuous features.  
Must be >=2 and >= number of categories for any categorical feature.'): 20, Param(parent='GBTClassifier_5b48f91c6e61', name='maxIter', doc='max number of iterations (>= 0).'): 10}, {Param(parent='GBTClassifier_5b48f91c6e61', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.'): 6, Param(parent='GBTClassifier_5b48f91c6e61', name='maxBins', doc='Max number of bins for discretizing continuous features.  Must be >=2 and >= number of categories for any categorical feature.'): 10, Param(parent='GBTClassifier_5b48f91c6e61', name='maxIter', doc='max number of iterations (>= 0).'): 5}, {Param(parent='GBTClassifier_5b48f91c6e61', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.'): 6, Param(parent='GBTClassifier_5b48f91c6e61', name='maxBins', doc='Max number of bins for discretizing continuous features.  Must be >=2 and >= number of categories for any categorical feature.'): 10, Param(parent='GBTClassifier_5b48f91c6e61', name='maxIter', doc='max number of iterations (>= 0).'): 10}, {Param(parent='GBTClassifier_5b48f91c6e61', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.'): 6, Param(parent='GBTClassifier_5b48f91c6e61', name='maxBins', doc='Max number of bins for discretizing continuous features.  Must be >=2 and >= number of categories for any categorical feature.'): 20, Param(parent='GBTClassifier_5b48f91c6e61', name='maxIter', doc='max number of iterations (>= 0).'): 5}, {Param(parent='GBTClassifier_5b48f91c6e61', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.'): 6, Param(parent='GBTClassifier_5b48f91c6e61', name='maxBins', doc='Max number of bins for discretizing continuous features.  
Must be >=2 and >= number of categories for any categorical feature.'): 20, Param(parent='GBTClassifier_5b48f91c6e61', name='maxIter', doc='max number of iterations (>= 0).'): 10}]
In [66]:
best_model_gbt=gbt_model.stages[-1].bestModel
featureImportances=best_model_gbt.featureImportances
feature_Importance_df = pd.DataFrame(list(zip(col, featureImportances)), columns=['Feature', 'featureImportances']).sort_values('featureImportances', ascending=True)
feature_Importance_df.plot(kind='barh',x='Feature',y='featureImportances',legend=False)
plt.title('Feature importance for GBTClassifier')
plt.xlabel('Feature importance')
Out[66]:
Text(0.5, 0, 'Feature importance')
In [67]:
best_model_gbt._java_obj.getMaxDepth()
Out[67]:
2
In [68]:
best_model_gbt._java_obj.getMaxBins()
Out[68]:
10
In [69]:
best_model_gbt._java_obj.getMaxIter()
Out[69]:
5
In [71]:
metrics(pred_gbt)
F1       : 0.8715553789083201
Accuracy : 0.890625
Precision: 0.9036016949152542
Recall   : 0.890625

Random Forest

Perform hyperparameter tuning over the following grid:

  • maxDepth: 2, 6
  • maxBins: 10, 20
  • numTrees: 5, 20, 50

As a result of tuning, maxDepth=6, maxBins=10 and numTrees=50 were selected.

In [85]:
from pyspark.ml.classification import RandomForestClassifier

rf = RandomForestClassifier(labelCol="target", featuresCol="scaled_features")
param = ParamGridBuilder().addGrid(rf.maxDepth, [2,6]).addGrid(rf.maxBins, [10,20]).addGrid(rf.numTrees, [5, 20, 50]).build()
model = build_model(rf, param)

rf_model = model.fit(train)
pred_rf = rf_model.transform(test)
In [86]:
best_model_rf=rf_model.stages[-1].bestModel
featureImportances=best_model_rf.featureImportances
feature_Importance_df = pd.DataFrame(list(zip(col, featureImportances)), columns=['Feature', 'featureImportances']).sort_values('featureImportances', ascending=True)
feature_Importance_df.plot(kind='barh',x='Feature',y='featureImportances',legend=False)
plt.title('Feature importance for Random Forest')
plt.xlabel('Feature importance')
Out[86]:
Text(0.5, 0, 'Feature importance')
In [87]:
best_model_rf._java_obj.getMaxDepth()
Out[87]:
6
In [91]:
best_model_rf._java_obj.getNumTrees()
Out[91]:
50
In [89]:
best_model_rf.featureImportances.norm
Out[89]:
<bound method SparseVector.norm of SparseVector(16, {0: 0.0795, 1: 0.0073, 2: 0.0126, 3: 0.0522, 4: 0.0574, 5: 0.0448, 6: 0.1013, 7: 0.0366, 8: 0.1612, 9: 0.0286, 10: 0.0676, 11: 0.0561, 12: 0.0447, 13: 0.0405, 14: 0.0302, 15: 0.1794})>
In [90]:
metrics(pred_rf)
F1       : 0.7537593984962406
Accuracy : 0.8125
Precision: 0.7620967741935484
Recall   : 0.8125

Conclusion

Model selection

We tried the following three methods supported by PySpark ML.

1. Logistic Regression

When the target variable and the explanatory variables have a linear relationship, logistic regression is excellent in terms of computational cost and model interpretability, so we use it as a baseline model. On the other hand, note that prediction accuracy can deteriorate due to multicollinearity when the explanatory variables are highly correlated.
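One way to see the multicollinearity problem numerically: when two predictors are nearly collinear, the matrix X^T X becomes ill-conditioned, which destabilizes the coefficient estimates. A minimal NumPy sketch with synthetic data (not the Sparkify features):

```python
# Illustrative only: nearly collinear predictors make X^T X ill-conditioned,
# which is why highly correlated features destabilize linear-model coefficients.
import numpy as np

rng = np.random.default_rng(0)
x1 = rng.normal(size=100)
x2 = x1 + rng.normal(scale=1e-3, size=100)  # x2 is almost a copy of x1
x3 = rng.normal(size=100)                   # an independent predictor

X_collinear = np.column_stack([x1, x2])
X_independent = np.column_stack([x1, x3])

# Condition number of X^T X: huge for the collinear design, modest otherwise.
cond_collinear = np.linalg.cond(X_collinear.T @ X_collinear)
cond_independent = np.linalg.cond(X_independent.T @ X_independent)
print(cond_collinear > cond_independent)
```

A large condition number means small changes in the data can swing the fitted coefficients wildly, even though predictions may look fine.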

2. Random Forest / 3. GBT

Assuming a nonlinear relationship between the target and the explanatory variables, we selected tree-based ensemble methods (random forest and gradient-boosted trees), which retain some interpretability. These methods also produce feature importances, so the classifier can highlight the variables that contribute most.

Model improvement
1. Model tuning
We ran a grid search over the maxDepth, maxBins, and maxIter/numTrees parameters of the tree-based models and adopted the parameter combination with the best score.
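For reference, the random-forest grid defined in the cell above expands to every combination of the listed values, and 5-fold cross-validation fits each combination once per fold. A quick plain-Python count:

```python
# Re-count of the random-forest parameter grid from the cell above:
# 2 maxDepth values x 2 maxBins values x 3 numTrees values = 12 combinations,
# and 5-fold cross-validation fits each combination 5 times.
from itertools import product

grid = {"maxDepth": [2, 6], "maxBins": [10, 20], "numTrees": [5, 20, 50]}
combinations = list(product(*grid.values()))
n_folds = 5
total_fits = len(combinations) * n_folds

print(len(combinations))  # 12 parameter combinations
print(total_fits)         # 60 model fits in total
```

This is why grid width matters on a small cluster: adding one more value to each parameter roughly doubles the total training time.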

2. Robustness
After splitting the data into train and test sets at a ratio of 7:3, we ran 5-fold cross-validation on the training data and adopted the average score. This is an effective way to improve robustness when the amount of data is relatively small.
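The evaluation scheme can be sketched in plain Python (illustrative only; the notebook uses PySpark's CrossValidator, and the row count and fold scores below are hypothetical):

```python
# Sketch of the evaluation scheme: 70/30 train/test split,
# then 5-fold cross-validation on the training portion.
import random

random.seed(42)
indices = list(range(100))  # stand-in for the user-level rows
random.shuffle(indices)

split = int(len(indices) * 0.7)
train_idx, test_idx = indices[:split], indices[split:]

k = 5
folds = [train_idx[i::k] for i in range(k)]  # k disjoint validation folds

# Every training row is used for validation exactly once across the k folds.
validated = sorted(i for fold in folds for i in fold)
assert validated == sorted(train_idx)

# The cross-validation score is the average of the per-fold scores.
fold_scores = [0.84, 0.86, 0.85, 0.83, 0.87]  # hypothetical fold scores
cv_score = sum(fold_scores) / len(fold_scores)
print(round(cv_score, 3))
```

The held-out 30% never participates in tuning, so the final test metrics are an unbiased estimate of generalization.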

Result

We trained three models and obtained the results below. When running a retention campaign for users at risk of churn, we want to minimize missed churners, while over-detection wastes campaign budget on users who would have stayed anyway. The F1 score, the harmonic mean of recall and precision, balances both concerns, so we use it to compare the models.

| Model | F1 score | Accuracy | Precision | Recall |
|---|---|---|---|---|
| Logistic Regression | 0.85 | 0.85 | 0.84 | 0.85 |
| GBTClassifier | 0.87 | 0.89 | 0.90 | 0.89 |
| Random Forest | 0.75 | 0.81 | 0.76 | 0.81 |
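For reference, F1 is the harmonic mean of precision and recall; a minimal sketch with toy confusion-matrix counts (illustrative numbers, not from the Sparkify data):

```python
# Toy confusion-matrix counts showing that F1 is the harmonic
# mean of precision and recall (numbers are illustrative only).
tp, fp, fn = 40, 10, 20  # true positives, false positives, false negatives

precision = tp / (tp + fp)  # 0.8
recall = tp / (tp + fn)     # 40/60 = 0.667
f1 = 2 * precision * recall / (precision + recall)

print(round(f1, 3))
```

The harmonic mean punishes imbalance: a model with high precision but poor recall (or vice versa) still gets a low F1, which is exactly the behavior we want for campaign targeting.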

Future considerations

Checking the feature importances of the GBT model, the following variables contributed most to classification accuracy:

  • Standard deviation of songs played per month
  • Days since registration
  • Standard deviation of songs played per day
  • Songs per session
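The same kind of ranking can be reproduced directly from a model's importance vector; for example, the random-forest importance vector printed earlier, ranked in plain Python (the index-to-feature-name mapping is defined earlier in the notebook):

```python
# Rank feature indices by importance, using the random-forest
# importance vector printed earlier (index -> importance).
importances = {0: 0.0795, 1: 0.0073, 2: 0.0126, 3: 0.0522, 4: 0.0574,
               5: 0.0448, 6: 0.1013, 7: 0.0366, 8: 0.1612, 9: 0.0286,
               10: 0.0676, 11: 0.0561, 12: 0.0447, 13: 0.0405,
               14: 0.0302, 15: 0.1794}

top4 = sorted(importances, key=importances.get, reverse=True)[:4]
print(top4)  # [15, 8, 6, 0]
```

Sorting by value rather than eyeballing the sparse-vector repr avoids mistakes when comparing rankings across models.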

Conclusion
We built a classifier on a distributed platform that captures the characteristics of users who tend to cancel. Churn proved difficult to predict with a linear model, and a tree-based classifier turned out to be the better fit. We identified the important features, but to understand actual market trends we would need much larger data, and both the prediction accuracy and the selected features may change at that scale. A large-scale distributed infrastructure is indispensable for building services on top of vast amounts of customer data.

Final Steps

Clean up your code, adding comments and renaming variables to make the code easier to read and maintain. Refer to the Spark Project Overview page and Data Scientist Capstone Project Rubric to make sure you are including all components of the capstone project and meet all expectations. Remember, this includes thorough documentation in a README file in a Github repository, as well as a web app or blog post.
